import re
import os
import math
import piper
import threading
from typing import Iterable

_g_piper_engines = {}
_g_piper_engines_lock = threading.Lock()

def fix_chinese_english_spacing(text, space=' '):
    """Insert *space* at every boundary between an ASCII letter/digit and a CJK ideograph.

    Boundaries are detected in both directions (``a中`` and ``中a``); only the
    CJK Unified Ideographs block (U+4E00-U+9FFF) counts as CJK here.

    :param text: input string.
    :param space: separator to insert. It is inserted literally -- it is never
        interpreted as a regex replacement template, so backslashes are safe.
    :return: *text* with the separator inserted at each boundary.
    """
    # Zero-width lookarounds leave both neighbours untouched, and the callable
    # replacement stops ``re`` from parsing *space* for escapes/group refs.
    insert = lambda _match: space
    text = re.sub(r'(?<=[a-zA-Z0-9])(?=[\u4e00-\u9fff])', insert, text)
    text = re.sub(r'(?<=[\u4e00-\u9fff])(?=[a-zA-Z0-9])', insert, text)
    return text

def detect_language(text: str, default_lang : str = 'zh', threshold : float = 0.5) -> str:
    """Guess the dominant language of *text*: 'zh', 'jp' or 'en'.

    Scoring per character:
      * CJK ideograph (U+4E00-U+9FFF)           -> +1   for 'zh'
      * hiragana/katakana range (U+3040-U+30FF) -> +1.5 for 'jp'
      * any other char where str.isalpha()      -> +0.1 for 'en'

    A language wins only if its score is strictly greater than
    ``(1 + threshold)`` times every other language's score. If none wins
    (including for empty text), *default_lang* is returned.
    """
    zh_score = jp_score = en_score = 0.0
    for ch in text:
        code = ord(ch)
        if 0x4e00 <= code <= 0x9fff:
            zh_score += 1
        elif 0x3040 <= code <= 0x30ff:
            jp_score += 1.5
        elif ch.isalpha():
            en_score += 0.1

    factor = 1 + threshold
    candidates = (('zh', zh_score), ('jp', jp_score), ('en', en_score))
    for name, score in candidates:
        beats_all = all(
            score > factor * other_score
            for other_name, other_score in candidates
            if other_name != name
        )
        if beats_all:
            return name
    return default_lang

# Romaji reading of each basic hiragana.
_KANA_ROMAJI = {
    'あ': 'a', 'い': 'i', 'う': 'u', 'え': 'e', 'お': 'o',
    'か': 'ka', 'き': 'ki', 'く': 'ku', 'け': 'ke', 'こ': 'ko',
    'さ': 'sa', 'し': 'shi', 'す': 'su', 'せ': 'se', 'そ': 'so',
    'た': 'ta', 'ち': 'chi', 'つ': 'tsu', 'て': 'te', 'と': 'to',
    'な': 'na', 'に': 'ni', 'ぬ': 'nu', 'ね': 'ne', 'の': 'no',
    'は': 'ha', 'ひ': 'hi', 'ふ': 'fu', 'へ': 'he', 'ほ': 'ho',
    'ま': 'ma', 'み': 'mi', 'む': 'mu', 'め': 'me', 'も': 'mo',
    'ら': 'ra', 'り': 'ri', 'る': 'ru', 'れ': 're', 'ろ': 'ro',
    'や': 'ya', 'ゆ': 'yu', 'よ': 'yo',
    'わ': 'wa', 'を': 'wo', 'ん': 'n',
    'が': 'ga', 'ぎ': 'gi', 'ぐ': 'gu', 'げ': 'ge', 'ご': 'go',
    'ざ': 'za', 'じ': 'ji', 'ず': 'zu', 'ぜ': 'ze', 'ぞ': 'zo',
    'だ': 'da', 'ぢ': 'ji', 'づ': 'zu', 'で': 'de', 'ど': 'do',
    'ば': 'ba', 'び': 'bi', 'ぶ': 'bu', 'べ': 'be', 'ぼ': 'bo',
    'ぱ': 'pa', 'ぴ': 'pi', 'ぷ': 'pu', 'ぺ': 'pe', 'ぽ': 'po',
}

# Chinese character used to roughly approximate the sound of each basic
# hiragana (lets a Chinese voice read kana aloud).
_KANA_HANZI = {
    'あ': '阿', 'い': '伊', 'う': '物', 'え': '爱', 'お': '欧',
    'か': '加', 'き': '几', 'く': '苦', 'け': '介', 'こ': '口',
    'さ': '左', 'し': '之', 'す': '寸', 'せ': '世', 'そ': '曾',
    'た': '塔', 'ち': '七', 'つ': '磁', 'て': '天', 'と': '投',
    'な': '那', 'に': '你', 'ぬ': '奴', 'ね': '内', 'の': '诺',
    'は': '哈', 'ひ': '比', 'ふ': '不', 'へ': '部', 'ほ': '保',
    'ま': '玛', 'み': '米', 'む': '木', 'め': '美', 'も': '莫',
    'ら': '拉', 'り': '利', 'る': '露', 'れ': '类', 'ろ': '搂',
    'や': '呀', 'ゆ': '由', 'よ': '哟',
    'わ': '瓦', 'を': '卧', 'ん': '嗯',
    'が': '我', 'ぎ': '义', 'ぐ': '古', 'げ': '外', 'ご': '钩',
    'ざ': '座', 'じ': '极', 'ず': '字', 'ぜ': '是', 'ぞ': '奏',
    'だ': '达', 'ぢ': '及', 'づ': '组', 'で': '得', 'ど': '都',
    'ば': '八', 'び': '比', 'ぶ': '补', 'べ': '被', 'ぼ': '波',
    'ぱ': '爬', 'ぴ': '披', 'ぷ': '铺', 'ぺ': '陪', 'ぽ': '破',
}

# A katakana code point is its hiragana counterpart + 0x60 (e.g. U+3042 'あ'
# and U+30A2 'ア').
_KATAKANA_OFFSET = 0x60


def _add_katakana_keys(table):
    """Return a copy of *table* where each hiragana key also has its katakana twin."""
    merged = dict(table)
    for kana, value in table.items():
        merged[chr(ord(kana) + _KATAKANA_OFFSET)] = value
    return merged


_KANA_TABLES = {
    'romaji': _add_katakana_keys(_KANA_ROMAJI),
    'hanzi': _add_katakana_keys(_KANA_HANZI),
}


def convert_japanese_romaji(text: str, mode: str = 'romaji') -> str:
    """Replace Japanese kana in *text* with romaji or with sound-alike hanzi.

    Hiragana and katakana (including ヤ/ユ/ヨ/ヲ) are both handled. Any
    character without a table entry -- e.g. small ゃ/っ, the long-vowel mark
    ー, or non-Japanese text -- is passed through unchanged.

    :param text: input string.
    :param mode: 'romaji' for Latin readings, 'hanzi' for Chinese characters.
    :return: the converted string.
    :raises ValueError: if *mode* is neither 'romaji' nor 'hanzi'.
    """
    try:
        table = _KANA_TABLES[mode]
    except KeyError:
        raise ValueError(f"unsupported mode {mode!r}; expected 'romaji' or 'hanzi'") from None
    return ''.join(table.get(ch, ch) for ch in text)

def split_at_puncation_marks(text: str) -> list[str]:
    """Split *text* into trimmed, non-empty segments at punctuation.

    A break happens at a newline, at any of the full-width (CJK) or ASCII
    punctuation marks listed below, and at a '.' immediately followed by
    whitespace. A '.' inside a number such as ``3.14``, or at the very end of
    the text, therefore does not split.
    """
    # Full-width CJK punctuation followed by the ASCII equivalents.
    marks = '，、。！？；：“”‘’（）【】{}《》～,,!?;:""\'\'()[]{}<>~'
    separator = '[\\n' + re.escape(marks) + ']|\\.\\s'
    segments = []
    for piece in re.split(separator, text):
        piece = piece.strip()
        if piece:
            segments.append(piece)
    return segments

def convert_text_readable(text: str) -> list[str]:
    """Turn markdown-ish text into a list of short segments suitable for TTS.

    Steps, in order:
      1. bare http(s) URLs are reduced to their host name;
      2. markdown links ``[label](target)`` are reduced to ``label``;
      3. fenced code blocks become a ``[<lang> code block]`` placeholder;
      4. ``**`` bold markers are removed;
      5. kana are mapped to hanzi via convert_japanese_romaji(mode='hanzi');
      6. spaces are inserted at ASCII/CJK boundaries;
      7. the text is split at punctuation (split_at_puncation_marks).
    """
    md_link_pattern = r'\[([^\]]+)\]\([^\)]+\)'
    url_pattern = r'https?://([a-zA-Z0-9\-_\.]+)(:[0-9]+)?(/[a-zA-Z0-9\-_\./%]*)?'
    # (?s) lets '.' cross newlines (replacing the slower `(\n|.)*?` form).
    # The body group is optional so an empty fence ("```\n```") is matched too,
    # and `\r?` accepts a CRLF line ending after the language tag.
    code_block_pattern = r'(?s)```(\w*)\r?\n(?:.*?\n)?```'
    text = re.sub(url_pattern, r'\1', text)
    text = re.sub(md_link_pattern, r'\1', text)
    text = re.sub(code_block_pattern, r'[\1 code block]', text)
    text = text.replace('**', '')
    # Workaround for a suspected Piper mispronunciation of "tten"
    # (original note: "bug of piper?") -- unverified.
    text = text.replace('tten', 'ttin')
    text = convert_japanese_romaji(text, mode='hanzi')
    text = fix_chinese_english_spacing(text)
    return split_at_puncation_marks(text)

# Directory holding the Piper .onnx voice models (a "models" folder next to
# this file). It is read each time a model is loaded.
download_path = os.path.join(os.path.dirname(__file__), 'models')

# Model file name for every (quality, language) pair that has a voice.
_VOICE_MODEL_FILES = {
    'medium': {
        'zh': 'zh_CN-huayan-medium.onnx',
        'en': 'en_US-hfc_female-medium.onnx',
    },
    'low': {
        'zh': 'zh_CN-huayan-x_low.onnx',
        'en': 'en_US-danny-low.onnx',
    },
}


def get_lang_engine(lang: str, quality: str = 'low') -> piper.PiperVoice | None:
    """Return the Piper voice for *lang* at *quality*, loading it on first use.

    Voices are loaded with ``use_cuda=False`` and memoised per
    ``(lang, quality)`` pair. A later request for the same language at a
    different quality therefore gets the matching model rather than whichever
    one happened to be loaded first.

    :param lang: language code; 'zh' and 'en' have models.
    :param quality: 'low' or 'medium'.
    :return: the loaded voice, or None when no model exists for the pair.
        The None is cached too, so the lookup is not retried.
    """
    key = (lang, quality)
    if key not in _g_piper_engines:
        # Re-check under the lock so two threads never load the same model.
        with _g_piper_engines_lock:
            if key not in _g_piper_engines:
                filename = _VOICE_MODEL_FILES.get(quality, {}).get(lang)
                if filename is None:
                    engine = None
                else:
                    engine = piper.PiperVoice.load(os.path.join(download_path, filename), use_cuda=False)
                _g_piper_engines[key] = engine
    return _g_piper_engines[key]

def speech_generator(lines: list[str], speed: float, silence: float, rate: int, quality: str) -> Iterable[bytes]:
    """Synthesize *lines* one by one and yield raw audio chunks at *rate* Hz.

    Each line's language is detected on its own. The default used for
    ambiguous lines is the language of the last line actually spoken
    (initially 'zh'). A line whose detected language has no voice (e.g. 'jp')
    is skipped without changing that default, so one unsupported line does not
    silence the lines that follow it.

    :param lines: text segments to speak; blank ones are skipped.
    :param speed: speed multiplier, passed to Piper as ``length_scale = 1 / speed``.
    :param silence: passed to Piper as ``sentence_silence``.
    :param rate: desired output sample rate; chunks are resampled when the
        voice's native rate differs.
    :param quality: voice quality forwarded to get_lang_engine.
    :return: a generator of raw PCM byte chunks.
    """
    lang = 'zh'
    for line in lines:
        line = line.strip()
        if not line:
            continue
        detected = detect_language(line, lang)
        engine = get_lang_engine(detected, quality)
        if not engine:
            # No voice for this language: drop the line but keep the previous
            # language as the fallback for later ambiguous lines.
            continue
        lang = detected
        stream = engine.synthesize_stream_raw(line, length_scale=1 / speed, sentence_silence=silence)
        native_rate = engine.config.sample_rate
        if rate == native_rate:
            yield from stream
            continue
        # NOTE: audioop is deprecated since Python 3.11 and removed in 3.13.
        import audioop
        # width=2, nchannels=1: assumes Piper emits 16-bit mono PCM -- confirm
        # against the piper version in use.
        state = None
        for chunk in stream:
            chunk, state = audioop.ratecv(chunk, 2, 1, native_rate, rate, state)
            yield chunk

def adjust_volume(gen: Iterable[bytes], volume: float) -> Iterable[bytes]:
    import audioop
    for chunk in gen:
        chunk = audioop.mul(chunk, 2, volume)
        yield chunk

def adjust_pitch(gen: Iterable[bytes], pitch: float, rate: int) -> Iterable[bytes]:
    import audioop
    state = None
    pitch = math.exp2(pitch / 5)
    for chunk in gen:
        chunk, state = audioop.ratecv(chunk, 2, 1, rate, round(rate / pitch), state)
        yield chunk

def text_to_speech(text: str, speed: float = 1.0, volume: float = 1.0, pitch: float = 0.0, silence: float = 0.0, rate: int = 0, quality: str = 'medium') -> tuple[Iterable[bytes], int]:
    """Convert *text* to a lazily generated stream of raw audio chunks.

    Parameter handling:
      * volume  -- clamped to [0, 1]; applied only when != 1.
      * pitch   -- clamped to [-3, 3]; applied (by resampling) only when != 0.
      * speed   -- divided by ``2 ** (pitch / 5)`` to offset the duration change
        of the pitch shift, then clamped to [0.25, 4].
      * silence -- clamped to [0, 2]; forwarded to the synthesizer.
      * rate    -- 0 selects 16000 for quality 'low' and 22050 otherwise; the
        result is clamped to [8000, 22050].
      * quality -- 'high' currently falls back to 'medium'.

    Text normalization runs immediately; synthesis happens as the returned
    iterator is consumed.

    :return: ``(chunks, rate)`` -- the audio chunk iterator and its sample rate.
    """
    lines = convert_text_readable(text)

    def clamp(value, low, high):
        return min(high, max(low, value))

    volume = clamp(volume, 0, 1)
    pitch = clamp(pitch, -3, 3)
    # adjust_pitch shortens the audio by this same factor, so pre-slow synthesis.
    speed = clamp(speed / math.exp2(pitch / 5), 0.25, 4)
    silence = clamp(silence, 0, 2)
    if rate == 0:
        rate = 16000 if quality == 'low' else 22050
    if quality == 'high':
        quality = 'medium'  # high quality not available yet
    rate = clamp(rate, 8000, 22050)

    chunks = speech_generator(lines, speed, silence, rate, quality)
    if volume != 1:
        chunks = adjust_volume(chunks, volume)
    if pitch != 0:
        chunks = adjust_pitch(chunks, pitch, rate)
    return chunks, rate

if __name__ == '__main__':
    # Manual smoke test: synthesize a sample sentence and play it back chunk by
    # chunk, printing the seconds spent waiting for each chunk (the first figure
    # also includes the text_to_speech call itself).
    import sounddevice
    import numpy as np
    import time

    started = time.time()
    # English sample, useful for exercising the 'en' voice:
    # stream, rate = text_to_speech('I have written a Python implementation of Radix Sort for you. The code snippet is displayed in a separate window. If you need any further assistance or explanation, feel free to ask!')
    pcm_chunks, sample_rate = text_to_speech('注意看, 这个男人叫里该隐.')
    for pcm in pcm_chunks:
        print(time.time() - started)
        # Integer-divide by 3 to play back at roughly a third of the amplitude.
        samples = np.frombuffer(pcm, dtype=np.int16) // 3
        sounddevice.play(samples, sample_rate)
        sounddevice.wait()
        started = time.time()
